#################################################################################
# Copyright (c) 2018-2021, Texas Instruments Incorporated - http://www.ti.com
# All Rights Reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################

from .pixel2pixelnet import *

###########################################
class MultiScalePixel2PixelNet(torch.nn.Module):
    """Two-scale pixel2pixel network with additive fusion.

    Builds one Pixel2PixelNet per resolution (exactly two are required).
    The forward pass runs scale 1 on a 0.5x-downsampled copy of the input,
    runs scale 0 on a center crop of the full-resolution input, aligns the
    two outputs spatially (upsample / zero-pad), and fuses them with an
    elementwise add.
    """
    def __init__(self, base_model, decoder_class, model_config):
        """
        Args:
            base_model: encoder backbone forwarded to each Pixel2PixelNet.
            decoder_class: decoder constructor forwarded to each Pixel2PixelNet.
            model_config: configuration object; must have num_resolutions == 2.
                model_config.num_decoders is ignored (effectively 1).
        """
        super().__init__()
        self.output_channels = model_config.output_channels
        self.output_type = model_config.output_type
        # interpolation mode used for both down- and up-sampling in forward()
        self.mode = 'bilinear'

        assert model_config.num_resolutions == 2, 'num_resolutions must be == 2'

        if model_config.num_decoders != 1:
            xnn.utils.print_once('num_decoders is being ignored and set to 1')

        # one complete Pixel2PixelNet per resolution: shared config, separate weights
        self.scales = torch.nn.ModuleList(
            [Pixel2PixelNet(base_model, decoder_class, model_config)
             for _ in range(model_config.num_resolutions)])

        # elementwise-add fusion of the two per-scale outputs
        self.scalefuse = xnn.layers.AddBlock()

        self._initialize_weights()


    def _initialize_weights(self):
        """Kaiming-init convolutions; BatchNorm weights start just below 1,
        biases at 0; Linear layers get small-variance normal weights."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, torch.nn.BatchNorm2d):
                if m.weight is not None:
                    # slightly below 1.0 to break exact symmetry at startup
                    torch.nn.init.constant_(m.weight, 1.0 - (1e-5))
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


    def forward(self, x_inp):
        """Run both scales and fuse their outputs.

        Args:
            x_inp: list containing exactly one input tensor of shape (N, C, H, W).

        Returns:
            List of per-task output tensors. In eval mode, 'segmentation'
            outputs are collapsed to class indices via argmax.
        """
        assert len(x_inp) == 1, 'only single input is supported'

        x0 = x_inp[0]

        # scale 1: half-resolution input, output upsampled back to full size
        x1 = F.interpolate(x0, scale_factor=0.5, mode=self.mode)
        y1 = self.scales[1]([x1])[0]
        y1 = F.interpolate(y1, scale_factor=2.0, mode=self.mode)

        # scale 0: full resolution, center-cropped to x1's spatial size.
        # Use explicit end indices: a pad of 0 with the form x[..., 0:-0]
        # would produce an empty tensor.
        pad_x0_v = int((x0.size(2) - x1.size(2)) // 2)
        pad_x0_h = int((x0.size(3) - x1.size(3)) // 2)
        x0_crop = x0[:, :, pad_x0_v:x0.size(2)-pad_x0_v, pad_x0_h:x0.size(3)-pad_x0_h]
        y0 = self.scales[0]([x0_crop])[0]

        # zero-pad the cropped-scale output back to the fused spatial size
        pad_y0_v = int((y1.size(2) - y0.size(2)) // 2)
        pad_y0_h = int((y1.size(3) - y0.size(3)) // 2)
        y0 = F.pad(y0, pad=(pad_y0_h, pad_y0_h, pad_y0_v, pad_y0_v), value=0.0)

        # fuse the two scales by elementwise addition
        y = self.scalefuse([y0, y1])

        # at inference time, collapse segmentation logits to class indices
        for o_idx, o_type in enumerate(self.output_type):
            if (not self.training) and (o_type == 'segmentation'):
                y[o_idx] = torch.argmax(y[o_idx], dim=1, keepdim=True)

        return y


